{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Question answering (TensorFlow)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Install the Transformers, Datasets, and Evaluate libraries to run this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install datasets evaluate transformers[sentencepiece]\n",
    "!apt install git-lfs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You will need to set up Git: adapt your email and name in the following cell."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!git config --global user.email \"you@example.com\"\n",
    "!git config --global user.name \"Your Name\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You will also need to be logged in to the Hugging Face Hub. Execute the following and enter your credentials."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from huggingface_hub import notebook_login\n",
    "\n",
    "notebook_login()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "raw_datasets = load_dataset(\"squad\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['id', 'title', 'context', 'question', 'answers'],\n",
       "        num_rows: 87599\n",
       "    })\n",
       "    validation: Dataset({\n",
       "        features: ['id', 'title', 'context', 'question', 'answers'],\n",
       "        num_rows: 10570\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "raw_datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Context: 'Architecturally, the school has a Catholic character. Atop the Main Building\\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend \"Venite Ad Me Omnes\". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.'\n",
       "Question: 'To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?'\n",
       "Answer: {'text': ['Saint Bernadette Soubirous'], 'answer_start': [515]}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "print(\"Context: \", raw_datasets[\"train\"][0][\"context\"])\n",
    "print(\"Question: \", raw_datasets[\"train\"][0][\"question\"])\n",
    "print(\"Answer: \", raw_datasets[\"train\"][0][\"answers\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['id', 'title', 'context', 'question', 'answers'],\n",
       "    num_rows: 0\n",
       "})"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "raw_datasets[\"train\"].filter(lambda x: len(x[\"answers\"][\"text\"]) != 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'text': ['Denver Broncos', 'Denver Broncos', 'Denver Broncos'], 'answer_start': [177, 177, 177]}\n",
       "{'text': ['Santa Clara, California', \"Levi's Stadium\", \"Levi's Stadium in the San Francisco Bay Area at Santa Clara, California.\"], 'answer_start': [403, 355, 355]}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "print(raw_datasets[\"validation\"][0][\"answers\"])\n",
    "print(raw_datasets[\"validation\"][2][\"answers\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24–10 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi\\'s Stadium in the San Francisco Bay Area at Santa Clara, California. As this was the 50th Super Bowl, the league emphasized the \"golden anniversary\" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as \"Super Bowl L\"), so that the logo could prominently feature the Arabic numerals 50.'\n",
       "'Where did Super Bowl 50 take place?'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "print(raw_datasets[\"validation\"][2][\"context\"])\n",
    "print(raw_datasets[\"validation\"][2][\"question\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer\n",
    "\n",
    "model_checkpoint = \"bert-base-cased\"\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer.is_fast"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'[CLS] To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France? [SEP] Architecturally, '\n",
       "'the school has a Catholic character. Atop the Main Building\\'s gold dome is a golden statue of the Virgin '\n",
       "'Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms '\n",
       "'upraised with the legend \" Venite Ad Me Omnes \". Next to the Main Building is the Basilica of the Sacred '\n",
       "'Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a '\n",
       "'replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette '\n",
       "'Soubirous in 1858. At the end of the main drive ( and in a direct line that connects through 3 statues '\n",
       "'and the Gold Dome ), is a simple, modern stone statue of Mary. [SEP]'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "context = raw_datasets[\"train\"][0][\"context\"]\n",
    "question = raw_datasets[\"train\"][0][\"question\"]\n",
    "\n",
    "inputs = tokenizer(question, context)\n",
    "tokenizer.decode(inputs[\"input_ids\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'[CLS] To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France? [SEP] Architecturally, the school has a Catholic character. Atop the Main Building\\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend \" Venite Ad Me Omnes \". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basi [SEP]'\n",
       "'[CLS] To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France? [SEP] the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend \" Venite Ad Me Omnes \". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin [SEP]'\n",
       "'[CLS] To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France? [SEP] Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive ( and in a direct line that connects through 3 [SEP]'\n",
       "'[CLS] To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France? [SEP]. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive ( and in a direct line that connects through 3 statues and the Gold Dome ), is a simple, modern stone statue of Mary. [SEP]'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inputs = tokenizer(\n",
    "    question,\n",
    "    context,\n",
    "    max_length=100,\n",
    "    truncation=\"only_second\",\n",
    "    stride=50,\n",
    "    return_overflowing_tokens=True,\n",
    ")\n",
    "\n",
    "for ids in inputs[\"input_ids\"]:\n",
    "    print(tokenizer.decode(ids))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'offset_mapping', 'overflow_to_sample_mapping'])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inputs = tokenizer(\n",
    "    question,\n",
    "    context,\n",
    "    max_length=100,\n",
    "    truncation=\"only_second\",\n",
    "    stride=50,\n",
    "    return_overflowing_tokens=True,\n",
    "    return_offsets_mapping=True,\n",
    ")\n",
    "inputs.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0, 0, 0, 0]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inputs[\"overflow_to_sample_mapping\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'The 4 examples gave 19 features.'\n",
       "'Here is where each comes from: [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3].'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inputs = tokenizer(\n",
    "    raw_datasets[\"train\"][2:6][\"question\"],\n",
    "    raw_datasets[\"train\"][2:6][\"context\"],\n",
    "    max_length=100,\n",
    "    truncation=\"only_second\",\n",
    "    stride=50,\n",
    "    return_overflowing_tokens=True,\n",
    "    return_offsets_mapping=True,\n",
    ")\n",
    "\n",
    "print(f\"The 4 examples gave {len(inputs['input_ids'])} features.\")\n",
    "print(f\"Here is where each comes from: {inputs['overflow_to_sample_mapping']}.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "([83, 51, 19, 0, 0, 64, 27, 0, 34, 0, 0, 0, 67, 34, 0, 0, 0, 0, 0],\n",
       " [85, 53, 21, 0, 0, 70, 33, 0, 40, 0, 0, 0, 68, 35, 0, 0, 0, 0, 0])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "answers = raw_datasets[\"train\"][2:6][\"answers\"]\n",
    "start_positions = []\n",
    "end_positions = []\n",
    "\n",
    "for i, offset in enumerate(inputs[\"offset_mapping\"]):\n",
    "    sample_idx = inputs[\"overflow_to_sample_mapping\"][i]\n",
    "    answer = answers[sample_idx]\n",
    "    start_char = answer[\"answer_start\"][0]\n",
    "    end_char = answer[\"answer_start\"][0] + len(answer[\"text\"][0])\n",
    "    sequence_ids = inputs.sequence_ids(i)\n",
    "\n",
    "    # Find the start and end of the context\n",
    "    idx = 0\n",
    "    while sequence_ids[idx] != 1:\n",
    "        idx += 1\n",
    "    context_start = idx\n",
    "    while sequence_ids[idx] == 1:\n",
    "        idx += 1\n",
    "    context_end = idx - 1\n",
    "\n",
    "    # If the answer is not fully inside the context, label is (0, 0)\n",
    "    if offset[context_start][0] > start_char or offset[context_end][1] < end_char:\n",
    "        start_positions.append(0)\n",
    "        end_positions.append(0)\n",
    "    else:\n",
    "        # Otherwise it's the start and end token positions\n",
    "        idx = context_start\n",
    "        while idx <= context_end and offset[idx][0] <= start_char:\n",
    "            idx += 1\n",
    "        start_positions.append(idx - 1)\n",
    "\n",
    "        idx = context_end\n",
    "        while idx >= context_start and offset[idx][1] >= end_char:\n",
    "            idx -= 1\n",
    "        end_positions.append(idx + 1)\n",
    "\n",
    "start_positions, end_positions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Theoretical answer: the Main Building, labels give: the Main Building'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "idx = 0\n",
    "sample_idx = inputs[\"overflow_to_sample_mapping\"][idx]\n",
    "answer = answers[sample_idx][\"text\"][0]\n",
    "\n",
    "start = start_positions[idx]\n",
    "end = end_positions[idx]\n",
    "labeled_answer = tokenizer.decode(inputs[\"input_ids\"][idx][start : end + 1])\n",
    "\n",
    "print(f\"Theoretical answer: {answer}, labels give: {labeled_answer}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Theoretical answer: a Marian place of prayer and reflection, decoded example: [CLS] What is the Grotto at Notre Dame? [SEP] Architecturally, the school has a Catholic character. Atop the Main Building\\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend \" Venite Ad Me Omnes \". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grot [SEP]'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "idx = 4\n",
    "sample_idx = inputs[\"overflow_to_sample_mapping\"][idx]\n",
    "answer = answers[sample_idx][\"text\"][0]\n",
    "\n",
    "decoded_example = tokenizer.decode(inputs[\"input_ids\"][idx])\n",
    "print(f\"Theoretical answer: {answer}, decoded example: {decoded_example}\")"
   ]
  },
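  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "max_length = 384\n",
    "stride = 128\n",
    "\n",
    "\n",
    "def preprocess_training_examples(examples):\n",
    "    # Strip stray leading/trailing whitespace from the questions before tokenizing\n",
    "    questions = [q.strip() for q in examples[\"question\"]]\n",
    "    # Truncate only the context; long contexts are split into overlapping features\n",
    "    inputs = tokenizer(\n",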
    "        questions,\n",
    "        examples[\"context\"],\n",
    "        max_length=max_length,\n",
    "        truncation=\"only_second\",\n",
    "        stride=stride,\n",
    "        return_overflowing_tokens=True,\n",
    "        return_offsets_mapping=True,\n",
    "        padding=\"max_length\",\n",
    "    )\n",
    "\n",
    "    offset_mapping = inputs.pop(\"offset_mapping\")\n",
    "    sample_map = inputs.pop(\"overflow_to_sample_mapping\")\n",
    "    answers = examples[\"answers\"]\n",
    "    start_positions = []\n",
    "    end_positions = []\n",
    "\n",
    "    for i, offset in enumerate(offset_mapping):\n",
    "        sample_idx = sample_map[i]\n",
    "        answer = answers[sample_idx]\n",
    "        start_char = answer[\"answer_start\"][0]\n",
    "        end_char = answer[\"answer_start\"][0] + len(answer[\"text\"][0])\n",
    "        sequence_ids = inputs.sequence_ids(i)\n",
    "\n",
    "        # Find the start and end of the context\n",
    "        idx = 0\n",
    "        while sequence_ids[idx] != 1:\n",
    "            idx += 1\n",
    "        context_start = idx\n",
    "        while sequence_ids[idx] == 1:\n",
    "            idx += 1\n",
    "        context_end = idx - 1\n",
    "\n",
    "        # If the answer is not fully inside the context, label is (0, 0)\n",
    "        if offset[context_start][0] > start_char or offset[context_end][1] < end_char:\n",
    "            start_positions.append(0)\n",
    "            end_positions.append(0)\n",
    "        else:\n",
    "            # Otherwise it's the start and end token positions\n",
    "            idx = context_start\n",
    "            while idx <= context_end and offset[idx][0] <= start_char:\n",
    "                idx += 1\n",
    "            start_positions.append(idx - 1)\n",
    "\n",
    "            idx = context_end\n",
    "            while idx >= context_start and offset[idx][1] >= end_char:\n",
    "                idx -= 1\n",
    "            end_positions.append(idx + 1)\n",
    "\n",
    "    inputs[\"start_positions\"] = start_positions\n",
    "    inputs[\"end_positions\"] = end_positions\n",
    "    return inputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(87599, 88729)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_dataset = raw_datasets[\"train\"].map(\n",
    "    preprocess_training_examples,\n",
    "    batched=True,\n",
    "    remove_columns=raw_datasets[\"train\"].column_names,\n",
    ")\n",
    "len(raw_datasets[\"train\"]), len(train_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def preprocess_validation_examples(examples):\n",
    "    questions = [q.strip() for q in examples[\"question\"]]\n",
    "    inputs = tokenizer(\n",
    "        questions,\n",
    "        examples[\"context\"],\n",
    "        max_length=max_length,\n",
    "        truncation=\"only_second\",\n",
    "        stride=stride,\n",
    "        return_overflowing_tokens=True,\n",
    "        return_offsets_mapping=True,\n",
    "        padding=\"max_length\",\n",
    "    )\n",
    "\n",
    "    sample_map = inputs.pop(\"overflow_to_sample_mapping\")\n",
    "    example_ids = []\n",
    "\n",
    "    for i in range(len(inputs[\"input_ids\"])):\n",
    "        # Record which original example this feature was created from\n",
    "        sample_idx = sample_map[i]\n",
    "        example_ids.append(examples[\"id\"][sample_idx])\n",
    "\n",
    "        # Keep offsets only for context tokens; question, special, and padding tokens become None\n",
    "        sequence_ids = inputs.sequence_ids(i)\n",
    "        offset = inputs[\"offset_mapping\"][i]\n",
    "        inputs[\"offset_mapping\"][i] = [\n",
    "            o if sequence_ids[k] == 1 else None for k, o in enumerate(offset)\n",
    "        ]\n",
    "\n",
    "    inputs[\"example_id\"] = example_ids\n",
    "    return inputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(10570, 10822)"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "validation_dataset = raw_datasets[\"validation\"].map(\n",
    "    preprocess_validation_examples,\n",
    "    batched=True,\n",
    "    remove_columns=raw_datasets[\"validation\"].column_names,\n",
    ")\n",
    "len(raw_datasets[\"validation\"]), len(validation_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "small_eval_set = raw_datasets[\"validation\"].select(range(100))\n",
    "trained_checkpoint = \"distilbert-base-cased-distilled-squad\"\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(trained_checkpoint)\n",
    "eval_set = small_eval_set.map(\n",
    "    preprocess_validation_examples,\n",
    "    batched=True,\n",
    "    remove_columns=raw_datasets[\"validation\"].column_names,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from transformers import TFAutoModelForQuestionAnswering\n",
    "\n",
    "# Drop the columns the model does not accept as inputs\n",
    "eval_set_for_model = eval_set.remove_columns([\"example_id\", \"offset_mapping\"])\n",
    "eval_set_for_model.set_format(\"numpy\")\n",
    "\n",
    "# The evaluation set is small, so pass all of it to the model as a single batch\n",
    "batch = {k: eval_set_for_model[k] for k in eval_set_for_model.column_names}\n",
    "trained_model = TFAutoModelForQuestionAnswering.from_pretrained(trained_checkpoint)\n",
    "\n",
    "outputs = trained_model(**batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "start_logits = outputs.start_logits.numpy()\n",
    "end_logits = outputs.end_logits.numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import collections\n",
    "\n",
    "example_to_features = collections.defaultdict(list)\n",
    "for idx, feature in enumerate(eval_set):\n",
    "    example_to_features[feature[\"example_id\"]].append(idx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "n_best = 20\n",
    "max_answer_length = 30\n",
    "predicted_answers = []\n",
    "\n",
    "for example in small_eval_set:\n",
    "    example_id = example[\"id\"]\n",
    "    context = example[\"context\"]\n",
    "    answers = []\n",
    "\n",
    "    for feature_index in example_to_features[example_id]:\n",
    "        start_logit = start_logits[feature_index]\n",
    "        end_logit = end_logits[feature_index]\n",
    "        offsets = eval_set[\"offset_mapping\"][feature_index]\n",
    "\n",
    "        # Indices of the n_best highest logits, best first\n",
    "        start_indexes = np.argsort(start_logit)[-1 : -n_best - 1 : -1].tolist()\n",
    "        end_indexes = np.argsort(end_logit)[-1 : -n_best - 1 : -1].tolist()\n",
    "        for start_index in start_indexes:\n",
    "            for end_index in end_indexes:\n",
    "                # Skip answers that are not fully in the context\n",
    "                if offsets[start_index] is None or offsets[end_index] is None:\n",
    "                    continue\n",
    "                # Skip answers with a length that is either < 0 or > max_answer_length.\n",
    "                if (\n",
    "                    end_index < start_index\n",
    "                    or end_index - start_index + 1 > max_answer_length\n",
    "                ):\n",
    "                    continue\n",
    "\n",
    "                answers.append(\n",
    "                    {\n",
    "                        \"text\": context[offsets[start_index][0] : offsets[end_index][1]],\n",
    "                        \"logit_score\": start_logit[start_index] + end_logit[end_index],\n",
    "                    }\n",
    "                )\n",
    "\n",
    "    # Fall back to an empty prediction if no valid span was found\n",
    "    best_answer = max(\n",
    "        answers, key=lambda x: x[\"logit_score\"], default={\"text\": \"\"}\n",
    "    )\n",
    "    predicted_answers.append({\"id\": example_id, \"prediction_text\": best_answer[\"text\"]})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import evaluate\n",
    "\n",
    "metric = evaluate.load(\"squad\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "theoretical_answers = [\n",
    "    {\"id\": ex[\"id\"], \"answers\": ex[\"answers\"]} for ex in small_eval_set\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'id': '56be4db0acb8001400a502ec', 'prediction_text': 'Denver Broncos'}\n",
       "{'id': '56be4db0acb8001400a502ec', 'answers': {'text': ['Denver Broncos', 'Denver Broncos', 'Denver Broncos'], 'answer_start': [177, 177, 177]}}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "print(predicted_answers[0])\n",
    "print(theoretical_answers[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'exact_match': 83.0, 'f1': 88.25}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "metric.compute(predictions=predicted_answers, references=theoretical_answers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm.auto import tqdm\n",
    "\n",
    "\n",
    "def compute_metrics(start_logits, end_logits, features, examples):\n",
    "    example_to_features = collections.defaultdict(list)\n",
    "    for idx, feature in enumerate(features):\n",
    "        example_to_features[feature[\"example_id\"]].append(idx)\n",
    "\n",
    "    predicted_answers = []\n",
    "    for example in tqdm(examples):\n",
    "        example_id = example[\"id\"]\n",
    "        context = example[\"context\"]\n",
    "        answers = []\n",
    "\n",
    "        # Loop through all features associated with that example\n",
    "        for feature_index in example_to_features[example_id]:\n",
    "            start_logit = start_logits[feature_index]\n",
    "            end_logit = end_logits[feature_index]\n",
    "            offsets = features[feature_index][\"offset_mapping\"]\n",
    "\n",
    "            start_indexes = np.argsort(start_logit)[-1 : -n_best - 1 : -1].tolist()\n",
    "            end_indexes = np.argsort(end_logit)[-1 : -n_best - 1 : -1].tolist()\n",
    "            for start_index in start_indexes:\n",
    "                for end_index in end_indexes:\n",
    "                    # Skip answers that are not fully in the context\n",
    "                    if offsets[start_index] is None or offsets[end_index] is None:\n",
    "                        continue\n",
    "                    # Skip answers with a length that is either < 0 or > max_answer_length\n",
    "                    if (\n",
    "                        end_index < start_index\n",
    "                        or end_index - start_index + 1 > max_answer_length\n",
    "                    ):\n",
    "                        continue\n",
    "\n",
    "                    answer = {\n",
    "                        \"text\": context[offsets[start_index][0] : offsets[end_index][1]],\n",
    "                        \"logit_score\": start_logit[start_index] + end_logit[end_index],\n",
    "                    }\n",
    "                    answers.append(answer)\n",
    "\n",
    "        # Select the answer with the best score\n",
    "        if len(answers) > 0:\n",
    "            best_answer = max(answers, key=lambda x: x[\"logit_score\"])\n",
    "            predicted_answers.append(\n",
    "                {\"id\": example_id, \"prediction_text\": best_answer[\"text\"]}\n",
    "            )\n",
    "        else:\n",
    "            predicted_answers.append({\"id\": example_id, \"prediction_text\": \"\"})\n",
    "\n",
    "    theoretical_answers = [{\"id\": ex[\"id\"], \"answers\": ex[\"answers\"]} for ex in examples]\n",
    "    return metric.compute(predictions=predicted_answers, references=theoretical_answers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'exact_match': 83.0, 'f1': 88.25}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "compute_metrics(start_logits, end_logits, eval_set, small_eval_set)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = TFAutoModelForQuestionAnswering.from_pretrained(model_checkpoint)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import DefaultDataCollator\n",
    "\n",
    "data_collator = DefaultDataCollator(return_tensors=\"tf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tf_train_dataset = model.prepare_tf_dataset(\n",
    "    train_dataset,\n",
    "    collate_fn=data_collator,\n",
    "    shuffle=True,\n",
    "    batch_size=16,\n",
    ")\n",
    "tf_eval_dataset = model.prepare_tf_dataset(\n",
    "    validation_dataset,\n",
    "    collate_fn=data_collator,\n",
    "    shuffle=False,\n",
    "    batch_size=16,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import create_optimizer\n",
    "import tensorflow as tf\n",
    "\n",
    "# The number of training steps is the number of samples in the dataset, divided by the batch size then multiplied\n",
    "# by the total number of epochs. Note that the tf_train_dataset here is a batched tf.data.Dataset,\n",
    "# not the original Hugging Face Dataset, so its len() is already num_samples // batch_size.\n",
    "num_train_epochs = 3\n",
    "num_train_steps = len(tf_train_dataset) * num_train_epochs\n",
    "optimizer, schedule = create_optimizer(\n",
    "    init_lr=2e-5,\n",
    "    num_warmup_steps=0,\n",
    "    num_train_steps=num_train_steps,\n",
    "    weight_decay_rate=0.01,\n",
    ")\n",
    "model.compile(optimizer=optimizer)\n",
    "\n",
    "# Train in mixed-precision float16\n",
    "tf.keras.mixed_precision.set_global_policy(\"mixed_float16\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers.keras_callbacks import PushToHubCallback\n",
    "\n",
    "callback = PushToHubCallback(output_dir=\"bert-finetuned-squad\", tokenizer=tokenizer)\n",
    "\n",
    "# We're going to do validation afterwards, so no validation mid-training\n",
    "model.fit(tf_train_dataset, callbacks=[callback], epochs=num_train_epochs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'exact_match': 81.18259224219489, 'f1': 88.67381321905516}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "predictions = model.predict(tf_eval_dataset)\n",
    "compute_metrics(\n",
    "    predictions[\"start_logits\"],\n",
    "    predictions[\"end_logits\"],\n",
    "    validation_dataset,\n",
    "    raw_datasets[\"validation\"],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'score': 0.9979003071784973,\n",
       " 'start': 78,\n",
       " 'end': 105,\n",
       " 'answer': 'Jax, PyTorch and TensorFlow'}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from transformers import pipeline\n",
    "\n",
    "# Replace this with your own checkpoint\n",
    "model_checkpoint = \"huggingface-course/bert-finetuned-squad\"\n",
    "question_answerer = pipeline(\"question-answering\", model=model_checkpoint)\n",
    "\n",
    "context = \"\"\"\n",
    "🤗 Transformers is backed by the three most popular deep learning libraries — Jax, PyTorch and TensorFlow — with a seamless integration\n",
    "between them. It's straightforward to train your models with one before loading them for inference with the other.\n",
    "\"\"\"\n",
    "question = \"Which deep learning libraries back 🤗 Transformers?\"\n",
    "question_answerer(question=question, context=context)"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "name": "Question answering (TensorFlow)",
   "provenance": []
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
